"""
Phase 2: Descriptive Statistics — Safe Asset Cliff
====================================================
Tables: transition matrix, summary stats by rating category, timeline
Figures: N safe issuers over time, OADR vs rating scatter

Output: table1_transitions.md, table2_summary_by_category.md,
        table3_downgrade_timeline.md, figure1_safe_count.png,
        figure2_oadr_rating_scatter.png
"""

import sys
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Absolute, machine-specific project root; every input and output location
# below is derived from it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_asset_cliff")
# main() reads cliff_panel.csv from PROCESSED_DIR; the table*.md files go to
# TABLES_DIR and the figure*.png files go to FIGURES_DIR.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"
FIGURES_DIR = PROJECT_DIR / "output" / "figures"

# phase1_data_assembly is looked up in the sibling directory
# <PROJECT_DIR.parent>/safe_assets/scripts, which is pushed to the front of
# sys.path (an import-time side effect) before the import below.
sys.path.insert(0, str(PROJECT_DIR.parent / "safe_assets" / "scripts"))
from phase1_data_assembly import RATING_SCALE

# Inverse of RATING_SCALE; this file looks it up with integer ratings to get
# the rating label back. If two keys of RATING_SCALE share a value, the later
# key wins.
REVERSE_SCALE = {v: k for k, v in RATING_SCALE.items()}

# Display label for each rating_category code used in the tables.
CATEGORY_LABELS = {4: 'AAA', 3: 'AA+', 2: 'AA', 1: 'AA-', 0: 'Below AA-'}


def write_markdown_table(path, title, headers, rows, notes=None):
    """Write a Markdown table to ``path`` and print the saved location.

    The file contains a level-3 heading with ``title``, the header row, an
    alignment row (first column left-aligned, all others right-aligned), one
    line per entry of ``rows`` and, when ``notes`` is truthy, an italic
    footnote separated by a blank line.

    Args:
        path: Destination file (``str`` or ``pathlib.Path``). Missing parent
            directories are created; an existing file is overwritten.
        title: Heading text.
        headers: Sequence of column header strings.
        rows: Iterable of rows; every cell is rendered with ``str()``.
        notes: Optional footnote text.
    """
    path = Path(path)
    # On a fresh checkout the output tree may not exist yet; without this,
    # write_text() raises FileNotFoundError.
    path.parent.mkdir(parents=True, exist_ok=True)

    alignment = [":--" if col == 0 else "--:" for col, _ in enumerate(headers)]

    lines = [
        f"### {title}",
        "",
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(alignment) + "|",
    ]
    lines.extend("| " + " | ".join(str(cell) for cell in row) + " |" for row in rows)
    if notes:
        lines.extend(["", f"*{notes}*"])
    lines.append("")  # guarantees a trailing newline after the join

    path.write_text("\n".join(lines), encoding="utf-8")
    print(f"  Saved: {path}")


def table1_transitions(df):
    """Table 1: Rating transition matrix.

    Pairs each observation's ``rating_category`` with the previous row of the
    same ``iso3`` (rows ordered by ``year``), counts the (from, to) pairs and
    writes the matrix plus row totals to ``table1_transitions.md`` in
    ``TABLES_DIR``. Pairs where either side is missing are dropped, and
    categories outside 4..0 are not counted.

    Args:
        df: Panel with ``iso3``, ``year`` and ``rating_category`` columns.
            The frame passed in is not modified.
    """
    print("\n  Building Table 1: Rating Transition Matrix ...")

    # NOTE(review): "previous row" is treated as "previous year". If a country
    # has missing years in the panel, the pair spans the gap -- confirm the
    # panel is balanced or filter on consecutive years upstream.
    ordered = df.sort_values(['iso3', 'year'])
    cat_to = ordered['rating_category']
    cat_from = ordered.groupby('iso3')['rating_category'].shift(1)
    observed = cat_to.notna() & cat_from.notna()

    # Build the pair frame explicitly instead of assigning new columns onto a
    # dropna() result, which is a chained-assignment pattern.
    pairs = pd.DataFrame({
        'cat_from': cat_from[observed].astype(int),
        'cat_to': cat_to[observed].astype(int),
    })
    # One pass over the data instead of one boolean scan per matrix cell.
    counts = pairs.groupby(['cat_from', 'cat_to']).size().to_dict()

    cats = [4, 3, 2, 1, 0]
    headers = ['From \\ To'] + [CATEGORY_LABELS[c] for c in cats] + ['Total']
    rows = []
    for origin in cats:
        row_counts = [int(counts.get((origin, dest), 0)) for dest in cats]
        rows.append(
            [CATEGORY_LABELS[origin]]
            + [str(n) for n in row_counts]
            + [str(sum(row_counts))]
        )

    write_markdown_table(
        TABLES_DIR / "table1_transitions.md",
        "Table 1: Rating Category Transition Matrix (Annual, 1990-2024)",
        headers, rows,
        notes="Categories: AAA (21), AA+ (20), AA (19), AA- (18), Below AA- (<18). 31 rated countries."
    )


def table2_summary_by_category(df):
    """Table 2: Summary stats by rating category.

    Writes ``table2_summary_by_category.md`` in ``TABLES_DIR`` with one row of
    means per summarised variable (columns: each rating category, then all
    rows of ``df``), followed by an observation-count row and a
    distinct-country-count row. Variables absent from ``df`` are skipped. A
    cell with no non-missing values is rendered as "-", including in the
    "All" column.

    Args:
        df: Panel with ``rating_category`` and ``iso3`` columns plus any of
            the summarised variables.
    """
    print("\n  Building Table 2: Summary by Rating Category ...")

    vars_to_summarize = {
        'old_dep': 'OADR',
        'govt_debt_gdp': 'Debt/GDP',
        'exp_rev_gap': 'Exp-Rev Gap',
        'r_minus_g': 'r-g',
        'rgdp_growth': 'GDP Growth',
        'inflation': 'Inflation',
    }

    cats = [4, 3, 2, 1, 0]
    # Slice the panel once per category; every row of the table reuses these.
    by_cat = {cat: df[df['rating_category'] == cat] for cat in cats}

    headers = ['Variable'] + [CATEGORY_LABELS[c] for c in cats] + ['All']
    rows = []

    for var, label in vars_to_summarize.items():
        if var not in df.columns:
            continue
        rows.append(
            [label]
            + [_format_mean(by_cat[cat][var]) for cat in cats]
            + [_format_mean(df[var])]
        )

    # The "All" cells cover every row of df, including rows whose
    # rating_category is missing or outside the listed categories.
    rows.append(
        ['N obs'] + [str(len(by_cat[cat])) for cat in cats] + [str(len(df))]
    )
    rows.append(
        ['N countries']
        + [str(by_cat[cat]['iso3'].nunique()) for cat in cats]
        + [str(df['iso3'].nunique())]
    )

    write_markdown_table(
        TABLES_DIR / "table2_summary_by_category.md",
        "Table 2: Summary Statistics by Rating Category",
        headers, rows,
        notes="Means reported. OADR as proportion (0-1). Fiscal variables as % of GDP."
    )


def _format_mean(values):
    """Return the mean of the non-missing ``values`` as ``'x.xx'``, or ``'-'`` if none."""
    present = values.dropna()
    if len(present) == 0:
        return "-"
    return f"{present.mean():.2f}"


def table3_downgrade_timeline(df):
    """Table 3: Timeline of downgrade events with context.

    Selects the rows with ``downgrade_any == 1``, orders them by ``year`` then
    ``iso3`` and writes one table line per event to
    ``table3_downgrade_timeline.md`` in ``TABLES_DIR``. Missing ratings are
    shown as "?", and a missing notch count or context value as "-".

    Args:
        df: Panel with ``downgrade_any``, ``iso3``, ``year``, ``rating_lag``,
            ``rating_numeric``, ``downgrade_notch`` and ``lost_safe`` columns.
            ``old_dep``, ``govt_debt_gdp`` and ``exp_rev_gap`` are optional.
    """
    print("\n  Building Table 3: Downgrade Timeline ...")

    # Boolean indexing and sort_values both return new frames, so the caller's
    # df is never touched.
    events = df[df['downgrade_any'] == 1].sort_values(['year', 'iso3'])

    headers = ['Country', 'Year', 'From', 'To', 'Notches', 'OADR',
               'Debt/GDP', 'Exp-Rev Gap', 'Lost Safe']
    rows = []

    for _, event in events.iterrows():
        notch = event['downgrade_notch']
        rows.append([
            event['iso3'],
            str(int(event['year'])),
            _rating_label(event['rating_lag']),
            # Guarded like the lagged rating: int(NaN) would raise ValueError.
            _rating_label(event['rating_numeric']),
            str(int(notch)) if pd.notna(notch) else '-',
            _format_or_dash(event.get('old_dep'), '.3f'),
            _format_or_dash(event.get('govt_debt_gdp'), '.1f'),
            _format_or_dash(event.get('exp_rev_gap'), '.1f'),
            'Yes' if event['lost_safe'] == 1 else '',
        ])

    write_markdown_table(
        TABLES_DIR / "table3_downgrade_timeline.md",
        "Table 3: Sovereign Downgrade Events (1990-2024)",
        headers, rows,
        notes="Events where S&P rating decreased from prior year. Lost Safe = fell below AA-."
    )


def _rating_label(value):
    """Map a numeric rating to its label via REVERSE_SCALE; '?' if missing or unknown."""
    if pd.isna(value):
        return '?'
    return REVERSE_SCALE.get(int(value), '?')


def _format_or_dash(value, spec):
    """Format ``value`` with format spec ``spec``, or return '-' when it is missing."""
    if pd.isna(value):
        return '-'
    return format(value, spec)


def figure1_safe_count(df):
    """Figure 1: N safe issuers and safe supply ratio over time.

    Draws a bar chart of the yearly sum of ``safe_issuer``. When ``df`` has a
    ``safe_supply_ratio`` column with at least one non-missing yearly value,
    that series is overlaid as a line on a secondary y-axis. The figure is
    saved as ``figure1_safe_count.png`` in ``FIGURES_DIR``, which is created
    if needed. Nothing is drawn when matplotlib cannot be imported or when
    ``df`` yields no yearly rows.

    Args:
        df: Panel with ``year`` and ``safe_issuer`` columns;
            ``safe_supply_ratio`` is optional.
    """
    print("\n  Building Figure 1: Safe Issuer Count ...")

    try:
        import matplotlib
        matplotlib.use('Agg')  # file-only backend, no display required
        import matplotlib.pyplot as plt
    except ImportError:
        print("    matplotlib not available, skipping figure")
        return

    yearly = df.groupby('year').agg(
        n_safe=('safe_issuer', 'sum'),
    ).reset_index()
    if yearly.empty:
        # max() of an empty column is NaN, which set_ylim() rejects.
        print("    no observations, skipping figure")
        return

    # Safe supply ratio if available (one value per year via GroupBy.first)
    has_ssr = False
    if 'safe_supply_ratio' in df.columns:
        ssr = df.groupby('year')['safe_supply_ratio'].first().reset_index()
        yearly = yearly.merge(ssr, on='year', how='left')
        has_ssr = yearly['safe_supply_ratio'].notna().any()

    fig, ax1 = plt.subplots(figsize=(10, 5))

    ax1.bar(yearly['year'], yearly['n_safe'], color='steelblue', alpha=0.7, label='N safe issuers')
    ax1.set_xlabel('Year', fontsize=12)
    ax1.set_ylabel('Number of Safe Issuers (AA- or above)', fontsize=11, color='steelblue')
    ax1.tick_params(axis='y', labelcolor='steelblue')
    # Leave headroom above the tallest bar.
    ax1.set_ylim(0, yearly['n_safe'].max() + 3)

    if has_ssr:
        ax2 = ax1.twinx()
        ax2.plot(yearly['year'], yearly['safe_supply_ratio'], 'r-o', markersize=3,
                 label='Safe supply ratio')
        ax2.set_ylabel('Safe Supply / Global GDP', fontsize=11, color='red')
        ax2.tick_params(axis='y', labelcolor='red')

    ax1.set_title('Panel A: Safe Issuers and Supply Over Time', fontsize=13)
    fig.tight_layout()

    out_path = FIGURES_DIR / "figure1_safe_count.png"
    # savefig() does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    # Close this figure explicitly rather than whichever one is "current".
    plt.close(fig)
    print(f"    Saved: {out_path}")


def figure2_scatter(df):
    """Figure 2: OADR vs rating_numeric scatter.

    Plots ``old_dep`` (multiplied by 100) against ``rating_numeric`` for every
    row where both are present. Rows with ``downgrade_any == 0`` are drawn as
    faint background points; rows with ``downgrade_any == 1`` are drawn as red
    crosses annotated with "<iso3> <year>". Rows whose ``downgrade_any`` is
    neither 0 nor 1 (for example missing) are not drawn. The figure is saved
    as ``figure2_oadr_rating_scatter.png`` in ``FIGURES_DIR``, which is
    created if needed. Nothing is drawn when matplotlib cannot be imported.

    Args:
        df: Panel with ``old_dep``, ``rating_numeric``, ``downgrade_any``,
            ``iso3`` and ``year`` columns.
    """
    print("\n  Building Figure 2: OADR vs Rating Scatter ...")

    try:
        import matplotlib
        matplotlib.use('Agg')  # file-only backend, no display required
        import matplotlib.pyplot as plt
    except ImportError:
        print("    matplotlib not available, skipping figure")
        return

    sub = df.dropna(subset=['old_dep', 'rating_numeric']).copy()

    fig, ax = plt.subplots(figsize=(10, 6))

    # Non-downgrade points
    normal = sub[sub['downgrade_any'] == 0]
    ax.scatter(normal['old_dep'] * 100, normal['rating_numeric'],
               alpha=0.2, s=15, c='steelblue', label='No downgrade')

    # Downgrade points, drawn on top of the background cloud
    dg = sub[sub['downgrade_any'] == 1]
    ax.scatter(dg['old_dep'] * 100, dg['rating_numeric'],
               alpha=0.9, s=60, c='red', marker='X', label='Downgrade event', zorder=5)

    # Label downgrade events
    for _, row in dg.iterrows():
        ax.annotate(f"{row['iso3']} {int(row['year'])}",
                    (row['old_dep'] * 100, row['rating_numeric']),
                    fontsize=7, alpha=0.8, xytext=(5, 5),
                    textcoords='offset points')

    ax.axhline(y=18, color='orange', linestyle='--', alpha=0.5, label='AA- threshold')
    ax.set_xlabel('Old-Age Dependency Ratio (%)', fontsize=12)
    ax.set_ylabel('S&P Rating (numeric)', fontsize=12)
    ax.set_title('Figure 2: OADR vs. Sovereign Rating', fontsize=13)
    ax.legend(fontsize=9)

    # Show "<number> (<label>)" on the y-axis ticks for ratings 21 down to 11
    yticks = [21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11]
    ylabels = [REVERSE_SCALE.get(y, '') for y in yticks]
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{y} ({l})" for y, l in zip(yticks, ylabels)])

    fig.tight_layout()

    out_path = FIGURES_DIR / "figure2_oadr_rating_scatter.png"
    # savefig() does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    # Close this figure explicitly rather than whichever one is "current".
    plt.close(fig)
    print(f"    Saved: {out_path}")


def main():
    """Load the cliff panel and build every Phase 2 table and figure in order."""
    rule = "=" * 70
    print(rule)
    print("PHASE 2: Descriptive Statistics")
    print(rule)

    panel = pd.read_csv(PROCESSED_DIR / "cliff_panel.csv")
    n_countries = panel['iso3'].nunique()
    print(f"Loaded cliff_panel: {n_countries} countries, {len(panel):,} obs")

    # Tables first, then figures; each builder receives the same panel.
    builders = (
        table1_transitions,
        table2_summary_by_category,
        table3_downgrade_timeline,
        figure1_safe_count,
        figure2_scatter,
    )
    for build in builders:
        build(panel)

    print("\n" + rule)
    print("Phase 2 complete.")
    print(rule)


if __name__ == "__main__":
    main()
